{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tensor Joining\n",
    "\n",
    "This section demonstrates the following methods of joining tensors:\n",
    "\n",
    "- Stacking\n",
    "- Concatenation\n",
    "\n",
    "Both operations take multiple input tensors and join them into a single output tensor.\n",
    "They differ in how they treat axes: concatenation joins the tensors along an existing axis, whereas stacking inserts a new axis, so the output has one more dimension than the inputs.\n",
    "\n",
    "Stacking can be used, for example, to combine separate coordinates into vectors, or to combine color planes into color images.\n",
    "Concatenation can be used, for example, to join tiles into a larger image or to append lists."
   ]
  },
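  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the difference described above, the shape arithmetic can be sketched with plain NumPy. This is only a stand-in: it assumes that `np.concatenate` and `np.stack` follow the same axis convention as the DALI operators used in the rest of this notebook, and the array names `a`, `b` and `c` are hypothetical.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Three hypothetical inputs that share the shape (2, 3, 4).\n",
    "a = np.zeros((2, 3, 4), dtype=np.int64)\n",
    "b = a + 100\n",
    "c = a + 200\n",
    "\n",
    "# Concatenation keeps the number of dimensions; the extents along the chosen axis add up.\n",
    "cat_shapes = [np.concatenate([a, b, c], axis=ax).shape for ax in range(3)]\n",
    "\n",
    "# Stacking adds one dimension whose extent equals the number of inputs.\n",
    "stack_shapes = [np.stack([a, b, c], axis=ax).shape for ax in range(4)]\n",
    "\n",
    "print(cat_shapes)    # [(6, 3, 4), (2, 9, 4), (2, 3, 12)]\n",
    "print(stack_shapes)  # [(3, 2, 3, 4), (2, 3, 3, 4), (2, 3, 3, 4), (2, 3, 4, 3)]\n",
    "```\n",
    "\n",
    "A 3D input has three valid concatenation axes but four valid stacking positions, because the new axis can also go after the innermost one."
   ]
  },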
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Concatenation\n",
    "\n",
    "This section shows how to concatenate along different axes. Concatenation requires the inputs to match in every dimension except the concatenation axis. Because the following example concatenates the same tensors along each axis in turn, these tensors must have identical shapes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nvidia.dali as dali\n",
    "import nvidia.dali.fn as fn\n",
    "import numpy as np\n",
    "\n",
    "np.random.seed(1234)\n",
    "\n",
    "# Base tensor of shape (2, 3, 4) holding the values 1..24\n",
    "arr = np.array(\n",
    "    [\n",
    "        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],\n",
    "        [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]],\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Three constant inputs; the offsets make the origin of each value easy to spot\n",
    "src1 = dali.types.Constant(arr)\n",
    "src2 = dali.types.Constant(arr + 100)\n",
    "src3 = dali.types.Constant(arr + 200)\n",
    "\n",
    "pipe_cat = dali.pipeline.Pipeline(batch_size=1, num_threads=3, device_id=0)\n",
    "with pipe_cat:\n",
    "    # Join the same three inputs along each existing axis in turn\n",
    "    cat_outer = fn.cat(src1, src2, src3, axis=0)  # shape (6, 3, 4)\n",
    "    cat_middle = fn.cat(src1, src2, src3, axis=1)  # shape (2, 9, 4)\n",
    "    cat_inner = fn.cat(src1, src2, src3, axis=2)  # shape (2, 3, 12)\n",
    "    pipe_cat.set_outputs(cat_outer, cat_middle, cat_inner)\n",
    "\n",
    "pipe_cat.build()\n",
    "o = pipe_cat.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Concatenation along outer axis:\n[[[  1   2   3   4]\n  [  5   6   7   8]\n  [  9  10  11  12]]\n\n [[ 13  14  15  16]\n  [ 17  18  19  20]\n  [ 21  22  23  24]]\n\n [[101 102 103 104]\n  [105 106 107 108]\n  [109 110 111 112]]\n\n [[113 114 115 116]\n  [117 118 119 120]\n  [121 122 123 124]]\n\n [[201 202 203 204]\n  [205 206 207 208]\n  [209 210 211 212]]\n\n [[213 214 215 216]\n  [217 218 219 220]\n  [221 222 223 224]]]\nShape:  (6, 3, 4)\n"
     ]
    }
   ],
   "source": [
    "print(\"Concatenation along outer axis:\")\n",
    "print(o[0].at(0))\n",
    "print(\"Shape: \", o[0].at(0).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Concatenation along middle axis:\n[[[  1   2   3   4]\n  [  5   6   7   8]\n  [  9  10  11  12]\n  [101 102 103 104]\n  [105 106 107 108]\n  [109 110 111 112]\n  [201 202 203 204]\n  [205 206 207 208]\n  [209 210 211 212]]\n\n [[ 13  14  15  16]\n  [ 17  18  19  20]\n  [ 21  22  23  24]\n  [113 114 115 116]\n  [117 118 119 120]\n  [121 122 123 124]\n  [213 214 215 216]\n  [217 218 219 220]\n  [221 222 223 224]]]\nShape:  (2, 9, 4)\n"
     ]
    }
   ],
   "source": [
    "print(\"Concatenation along middle axis:\")\n",
    "print(o[1].at(0))\n",
    "print(\"Shape: \", o[1].at(0).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Concatenation along inner axis:\n[[[  1   2   3   4 101 102 103 104 201 202 203 204]\n  [  5   6   7   8 105 106 107 108 205 206 207 208]\n  [  9  10  11  12 109 110 111 112 209 210 211 212]]\n\n [[ 13  14  15  16 113 114 115 116 213 214 215 216]\n  [ 17  18  19  20 117 118 119 120 217 218 219 220]\n  [ 21  22  23  24 121 122 123 124 221 222 223 224]]]\nShape:  (2, 3, 12)\n"
     ]
    }
   ],
   "source": [
    "print(\"Concatenation along inner axis:\")\n",
    "print(o[2].at(0))\n",
    "print(\"Shape: \", o[2].at(0).shape)"
   ]
  },
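  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The example above reuses identically shaped inputs because it concatenates them along every axis in turn. For a single axis, concatenation in general only needs the inputs to agree in the remaining dimensions. The sketch below illustrates that rule with NumPy's `np.concatenate` as a stand-in; the arrays `tall` and `short` are hypothetical and `fn.cat` is not called here.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical inputs that differ only in the extent of axis 1.\n",
    "tall = np.ones((2, 3, 4))\n",
    "short = np.ones((2, 1, 4))\n",
    "\n",
    "# Axis 1 is the only axis where the shapes differ, so joining along it works.\n",
    "joined = np.concatenate([tall, short], axis=1)\n",
    "print(joined.shape)  # (2, 4, 4)\n",
    "\n",
    "# Along axis 0 the extents of axis 1 would have to match, so NumPy raises an error.\n",
    "try:\n",
    "    np.concatenate([tall, short], axis=0)\n",
    "    mismatch_rejected = False\n",
    "except ValueError:\n",
    "    mismatch_rejected = True\n",
    "print(mismatch_rejected)  # True\n",
    "```"
   ]
  },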
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stacking\n",
    "\n",
    "Stacking inserts a new axis at the position given by `axis`, and the extent of that axis equals the number of inputs. The new axis can even be inserted _after_ the innermost axis; in that case, the values from the input tensors are interleaved.\n",
    "\n",
    "The following example stacks the same inputs that were used for concatenation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "pipe_stack = dali.pipeline.Pipeline(batch_size=1, num_threads=3, device_id=0)\n",
    "with pipe_stack:\n",
    "    # Insert the new axis (extent 3, one entry per input) at each possible position\n",
    "    st_outermost = fn.stack(src1, src2, src3, axis=0)  # shape (3, 2, 3, 4)\n",
    "    st_1 = fn.stack(src1, src2, src3, axis=1)  # shape (2, 3, 3, 4)\n",
    "    st_2 = fn.stack(src1, src2, src3, axis=2)  # shape (2, 3, 3, 4)\n",
    "    st_new_inner = fn.stack(src1, src2, src3, axis=3)  # shape (2, 3, 4, 3)\n",
    "    pipe_stack.set_outputs(st_outermost, st_1, st_2, st_new_inner)\n",
    "\n",
    "pipe_stack.build()\n",
    "o = pipe_stack.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Stacking - insert outermost axis:\n[[[[  1   2   3   4]\n   [  5   6   7   8]\n   [  9  10  11  12]]\n\n  [[ 13  14  15  16]\n   [ 17  18  19  20]\n   [ 21  22  23  24]]]\n\n\n [[[101 102 103 104]\n   [105 106 107 108]\n   [109 110 111 112]]\n\n  [[113 114 115 116]\n   [117 118 119 120]\n   [121 122 123 124]]]\n\n\n [[[201 202 203 204]\n   [205 206 207 208]\n   [209 210 211 212]]\n\n  [[213 214 215 216]\n   [217 218 219 220]\n   [221 222 223 224]]]]\nShape:  (3, 2, 3, 4)\n"
     ]
    }
   ],
   "source": [
    "print(\"Stacking - insert outermost axis:\")\n",
    "print(o[0].at(0))\n",
    "print(\"Shape: \", o[0].at(0).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Stacking - new axis before 1:\n[[[[  1   2   3   4]\n   [  5   6   7   8]\n   [  9  10  11  12]]\n\n  [[101 102 103 104]\n   [105 106 107 108]\n   [109 110 111 112]]\n\n  [[201 202 203 204]\n   [205 206 207 208]\n   [209 210 211 212]]]\n\n\n [[[ 13  14  15  16]\n   [ 17  18  19  20]\n   [ 21  22  23  24]]\n\n  [[113 114 115 116]\n   [117 118 119 120]\n   [121 122 123 124]]\n\n  [[213 214 215 216]\n   [217 218 219 220]\n   [221 222 223 224]]]]\nShape:  (2, 3, 3, 4)\n"
     ]
    }
   ],
   "source": [
    "print(\"Stacking - new axis before 1:\")\n",
    "print(o[1].at(0))\n",
    "print(\"Shape: \", o[1].at(0).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Stacking - new axis before 2:\n[[[[  1   2   3   4]\n   [101 102 103 104]\n   [201 202 203 204]]\n\n  [[  5   6   7   8]\n   [105 106 107 108]\n   [205 206 207 208]]\n\n  [[  9  10  11  12]\n   [109 110 111 112]\n   [209 210 211 212]]]\n\n\n [[[ 13  14  15  16]\n   [113 114 115 116]\n   [213 214 215 216]]\n\n  [[ 17  18  19  20]\n   [117 118 119 120]\n   [217 218 219 220]]\n\n  [[ 21  22  23  24]\n   [121 122 123 124]\n   [221 222 223 224]]]]\nShape:  (2, 3, 3, 4)\n"
     ]
    }
   ],
   "source": [
    "print(\"Stacking - new axis before 2:\")\n",
    "print(o[2].at(0))\n",
    "print(\"Shape: \", o[2].at(0).shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Stacking - new innermost axis:\n[[[[  1 101 201]\n   [  2 102 202]\n   [  3 103 203]\n   [  4 104 204]]\n\n  [[  5 105 205]\n   [  6 106 206]\n   [  7 107 207]\n   [  8 108 208]]\n\n  [[  9 109 209]\n   [ 10 110 210]\n   [ 11 111 211]\n   [ 12 112 212]]]\n\n\n [[[ 13 113 213]\n   [ 14 114 214]\n   [ 15 115 215]\n   [ 16 116 216]]\n\n  [[ 17 117 217]\n   [ 18 118 218]\n   [ 19 119 219]\n   [ 20 120 220]]\n\n  [[ 21 121 221]\n   [ 22 122 222]\n   [ 23 123 223]\n   [ 24 124 224]]]]\nShape:  (2, 3, 4, 3)\n"
     ]
    }
   ],
   "source": [
    "print(\"Stacking - new innermost axis:\")\n",
    "print(o[3].at(0))\n",
    "print(\"Shape: \", o[3].at(0).shape)"
   ]
  }
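  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The interleaving seen in the last output can be checked with a short NumPy sketch. It assumes that `np.stack` mirrors the axis semantics of `fn.stack` shown above; the names `inputs`, `stacked` and `recovered` are hypothetical and only serve this illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Rebuild the three (2, 3, 4) inputs used in this notebook.\n",
    "arr = np.arange(1, 25).reshape(2, 3, 4)\n",
    "inputs = [arr, arr + 100, arr + 200]\n",
    "\n",
    "# Place the new axis after the innermost one.\n",
    "stacked = np.stack(inputs, axis=3)\n",
    "print(stacked.shape)     # (2, 3, 4, 3)\n",
    "print(stacked[0, 0, 0])  # [  1 101 201]\n",
    "\n",
    "# Indexing the new last axis with k recovers input k unchanged.\n",
    "recovered = all(np.array_equal(stacked[..., k], inputs[k]) for k in range(3))\n",
    "print(recovered)  # True\n",
    "```\n",
    "\n",
    "Each innermost triple holds the values found at one position in the three inputs, which is the layout needed when combining separate color planes into an interleaved image."
   ]
  }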
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}